/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
*   http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
#include <thrift/thrift-config.h>

#include <thrift/transport/TSSLContext.h>
#include <thrift/transport/TSSLSocket.h>

#include <openssl/err.h>
#include <openssl/rand.h>
#include <openssl/ssl.h>
#include <openssl/x509v3.h>
#include <boost/foreach.hpp>
#include <boost/bind.hpp>

namespace apache {
namespace thrift {
namespace transport {

// SSLContext implementation
/**
 * Creates the underlying boost::asio SSL context for the requested version.
 *
 * SSLv2 is always disabled. TLSv1 additionally disables SSLv3, and any other
 * version value falls back to the sslv23 negotiation method. The asio
 * default_workarounds option and OpenSSL's SSL_MODE_AUTO_RETRY are enabled.
 * Peer verification starts as verify_none and peer-name checking starts off.
 *
 * @param version protocol version selector.
 */
TSSLContext::TSSLContext(SSLVersion version)
  : verifyMode_(boost::asio::ssl::verify_none)
{
  typedef boost::asio::ssl::context AsioContext;

  // Start from the "default:" choice and override for the specific versions.
  AsioContext::method protocol = AsioContext::sslv23;
  AsioContext::options flags = AsioContext::no_sslv2;

  if (version == TLSv1) {
    protocol = AsioContext::tlsv1;
    flags |= AsioContext::no_sslv3;
  } else if (version == SSLv3) {
    protocol = AsioContext::sslv3;
  }
  flags |= AsioContext::default_workarounds;

  checkPeerName_ = false;

  ctx_ = boost::shared_ptr<AsioContext>(new AsioContext(protocol));
  ctx_->set_options(flags);
  SSL_CTX_set_mode(ctx_->native_handle(), SSL_MODE_AUTO_RETRY);
}

// No explicit cleanup is performed here; every member is released by its own
// destructor.
TSSLContext::~TSSLContext()
{
}

/**
 * Sets the cipher list used by this context.
 *
 * The list is applied via setCiphersOrThrow() first. It is remembered in
 * providedCiphersString_ only after that call returns without throwing, so a
 * rejected list is never recorded as the "provided" one.
 *
 * @param ciphers OpenSSL cipher list string.
 * @throws TSSLException if the list cannot be applied.
 */
void TSSLContext::ciphers(const std::string& ciphers)
{
  setCiphersOrThrow(ciphers);
  providedCiphersString_ = ciphers;
}

/**
 * Applies a cipher list to the native SSL_CTX, throwing on failure.
 *
 * The OpenSSL error queue is consulted only when SSL_CTX_set_cipher_list
 * reports failure (rc == 0). Previously any error already sitting in the
 * queue, possibly left by unrelated earlier calls, made a successful call
 * throw.
 *
 * @param ciphers OpenSSL cipher list string.
 * @throws TSSLException with the drained OpenSSL errors if any are queued,
 *         or "None of specified ciphers are supported" otherwise.
 */
void TSSLContext::setCiphersOrThrow(const std::string& ciphers)
{
  int rc = SSL_CTX_set_cipher_list(ctx_->native_handle(), ciphers.c_str());
  if (rc == 0) {
    if (ERR_peek_error() != 0) {
      throw TSSLException("SSL_CTX_set_cipher_list: " + getSSLErrors());
    }
    throw TSSLException("None of specified ciphers are supported");
  }
}

/**
 * Returns the verification mode most recently recorded in verifyMode_.
 * It is verify_none after construction, and authenticate() updates it.
 */
boost::asio::ssl::verify_mode TSSLContext::getVerificationMode()
{
  const boost::asio::ssl::verify_mode current = verifyMode_;
  return current;
}

/**
 * Configures peer certificate verification for this context.
 *
 * When checkPeerCert is true, the context requires a peer certificate
 * (verify_peer | verify_fail_if_no_peer_cert | verify_client_once), and
 * checkPeerName/peerName are recorded. Otherwise verification is disabled and
 * name checking is turned off.
 *
 * Cached state (verifyMode_, checkPeerName_, peerFixedName_) is updated only
 * after the mode has been applied to the native context. A failure therefore
 * leaves the previous configuration intact.
 *
 * @param checkPeerCert whether to require and verify a peer certificate.
 * @param checkPeerName whether to check the peer's name (only honoured when
 *                      checkPeerCert is true).
 * @param peerName      expected peer name (only recorded when checkPeerCert is true).
 * @throws TSSLException if set_verify_mode reports an error.
 */
void TSSLContext::authenticate(bool checkPeerCert, bool checkPeerName, const std::string& peerName)
{
  boost::asio::ssl::verify_mode mode = boost::asio::ssl::verify_none;
  if (checkPeerCert) {
    mode = boost::asio::ssl::verify_peer |
           boost::asio::ssl::verify_fail_if_no_peer_cert |
           boost::asio::ssl::verify_client_once;
  }

  boost::system::error_code ec;
  ctx_->set_verify_mode(mode, ec);
  if (ec) {
    throw TSSLException("SSL_CTX_set_verify: " + ec.message());
  }

  // Commit cached state only now that the native context accepted the mode.
  verifyMode_ = mode;
  if (checkPeerCert) {
    checkPeerName_ = checkPeerName;
    peerFixedName_ = peerName;
  }
  else {
    checkPeerName_ = false; // can't check name without cert!
    peerFixedName_.clear();
  }
}

void TSSLContext::loadCertificate(const std::string& path, const std::string& format)
{
  if (format == "PEM") {
    boost::system::error_code ec;
    ctx_->use_certificate_chain_file(path, ec);
    if (ec) {
      throw TSSLException("SSL_CTX_use_certificate_chain_file: " + ec.message());
    }
  }
  else {
    throw TSSLException("Unsupported certificate format: " + format);
  }
}

/**
 * Loads a private key file into the context.
 *
 * Previously a non-PEM format was silently ignored, so no key was loaded and
 * the caller was never told. It now throws, matching loadCertificate().
 *
 * @param path   path to the private key file.
 * @param format encoding of the file; only "PEM" is accepted.
 * @throws TSSLException for an unsupported format or when loading fails.
 */
void TSSLContext::loadPrivateKey(const std::string& path, const std::string& format)
{
  if (format != "PEM") {
    throw TSSLException("Unsupported private key format: " + format);
  }

  boost::system::error_code ec;
  ctx_->use_private_key_file(path, boost::asio::ssl::context::pem, ec);
  if (ec) {
    throw TSSLException("SSL_CTX_use_PrivateKey_file: " + ec.message());
  }
}

void TSSLContext::loadTrustedCertificates(const std::string& path)
{
  boost::system::error_code ec;
  ctx_->load_verify_file(path, ec);
  if (ec) {
    throw TSSLException("SSL_CTX_load_verify_locations: " + ec.message());
  }
}

/**
 * Installs an existing X509_STORE as the certificate verification store of
 * the native SSL_CTX.
 *
 * NOTE(review): per the OpenSSL SSL_CTX_set_cert_store documentation, the
 * context takes ownership of `store` and frees any previously installed
 * store. Callers should therefore not free `store` themselves. Confirm
 * against the callers.
 *
 * @param store certificate store to install.
 */
void TSSLContext::loadTrustedCertificates(X509_STORE* store)
{
  SSL_CTX* const nativeCtx = ctx_->native_handle();
  SSL_CTX_set_cert_store(nativeCtx, store);
}

void TSSLContext::loadClientCAList(const std::string& path)
{
  STACK_OF(X509_NAME)* clientCAs = SSL_load_client_CA_file(path.c_str());
  if (clientCAs == NULL) {
    TLogging::log_err("Unable to load ca file: %s", path.c_str());
    return;
  }
  SSL_CTX_set_client_CA_list(ctx_->native_handle(), clientCAs);
}

/**
 * Stores the password collector and registers passwordCallback (bound to
 * this object) as the context's password callback.
 *
 * NOTE(review): the callback captures `this` as a raw pointer. It must not be
 * invoked after this TSSLContext is destroyed, which matters if ctx_ is
 * shared elsewhere. Confirm.
 *
 * @param collector collector consulted when a password is required.
 */
void TSSLContext::passwordCollector(const boost::shared_ptr<PasswordCollector>& collector)
{
  this->collector_ = collector;
  ctx_->set_password_callback(
      boost::bind(&TSSLContext::passwordCallback, this, _1, _2));
}

/**
 * Match a name with a pattern. The pattern may include wildcards.
 *
 * A '*' in the pattern consumes characters of "host" up to (not including)
 * the next '.' or the end of "host", so a wildcard never spans a dot. All
 * other characters are compared case-insensitively. Because the loop stops
 * once "host" is exhausted, a trailing '*' that has nothing left to consume
 * (e.g. pattern "foo*" vs host "foo") does not match.
 *
 * @param  host    NUL-terminated host name, typically the name of the remote host
 * @param  pattern Name retrieved from certificate; only the first "size"
 *                 characters are read
 * @param  size    Number of characters of "pattern" to consider
 * @return True if both all "size" pattern characters and all of "host" were
 *         consumed. False otherwise.
 */
bool TSSLContext::matchName(const char* host, const char* pattern, int size)
{
  int i = 0;
  int j = 0;
  while (i < size && host[j] != '\0') {
    // toupper() is only defined for values representable as unsigned char
    // (or EOF). Passing a negative plain char, e.g. a non-ASCII byte, is UB.
    const int patternChar = toupper(static_cast<unsigned char>(pattern[i]));
    const int hostChar = toupper(static_cast<unsigned char>(host[j]));
    if (patternChar == hostChar) {
      ++i;
      ++j;
      continue;
    }
    if (pattern[i] == '*') {
      while (host[j] != '.' && host[j] != '\0') {
        ++j;
      }
      ++i;
      continue;
    }
    break;
  }
  return i == size && host[j] == '\0';
}

std::string TSSLContext::passwordCallback(size_t size, boost::asio::ssl::context::password_purpose purpose)
{
  if (!(this->passwordCollector())) {
    return "";
  }
  std::string userPassword;
  this->passwordCollector()->getPassword(userPassword, size);
  int length = userPassword.size();
  if (length > size) {
    return userPassword.substr(0, size);
  }
  return userPassword;
}

std::string TSSLContext::getSSLErrors()
{
  std::string errors;
  unsigned long  errorCode;
  errors.reserve(512);
  while ((errorCode = ERR_get_error()) != 0) {
    if (!errors.empty()) {
      errors += "; ";
    }
    const char* reason = ERR_reason_error_string(errorCode);
    if (reason == NULL) {
      errors += std::string("SSL error # ") + apache::thrift::to_string(errorCode);
    }
    else {
      errors += reason;
    }
  }
  if (errors.empty()) {
    errors = "error unknown";
  }
  return errors;
}

/**
 * Writes the result of collector.describe() to the stream.
 *
 * @return the same stream, for chaining.
 */
std::ostream& operator<<(std::ostream& os, const PasswordCollector& collector)
{
  return os << collector.describe();
}

}
}
}
